
// Copyright Epic Games, Inc. All Rights Reserved.

#include "vfsprovider.h"

#if ZEN_WITH_VFS

#	include <zencore/filesystem.h>
#	include <zencore/scopeguard.h>

#	include <system_error>

namespace zen {

// A node is treated as a directory either when it explicitly carries the
// DirectoryMarker size (as virtual mount nodes do before they are populated)
// or when it already has at least one child.
bool
VfsTreeNode::IsDirectory() const
{
	if (FileSize == DirectoryMarker)
	{
		return true;
	}

	return !GetChildren().empty();
}

// Resolves a path relative to this node, in the form ProjFS hands to callbacks
// ('\\'-separated UTF-16, e.g. PRJ_CALLBACK_DATA::FilePathName).
//
// Each component is matched against the children with PrjFileNameCompare, so
// lookups follow ProjFS's file name collation rather than exact byte equality.
// A virtual directory is populated from its data source (PopulateVirtualNode)
// before its children are searched.
//
// Returns nullptr if any component along the path cannot be found.
VfsTreeNode*
VfsTreeNode::FindNode(std::wstring_view NodeName)
{
	if (const auto SlashIndex = NodeName.find(L'\\'); SlashIndex != std::wstring_view::npos)
	{
		// Resolve the first component here, then let that node resolve the rest
		if (VfsTreeNode* DirNode = FindNode(NodeName.substr(0, SlashIndex)))
		{
			return DirNode->FindNode(NodeName.substr(SlashIndex + 1));
		}

		return nullptr;
	}

	// PrjFileNameCompare takes NUL-terminated strings, which a wstring_view does
	// not guarantee, so copy the name into a builder first
	ExtendableWideStringBuilder<64> NodeNameZ;
	NodeNameZ.Append(NodeName);

	PopulateVirtualNode();

	for (auto& Child : GetChildren())
	{
		ExtendableWideStringBuilder<64> ChildName;
		Utf8ToWide(Child->Name, ChildName);

		if (PrjFileNameCompare(ChildName.c_str(), NodeNameZ.c_str()) == 0)
		{
			return Child.get();
		}
	}

	return nullptr;
}

// Looks for a direct child whose UTF-8 name is exactly NodeName
// (case-sensitive, byte comparison). Unlike the wide-string overload this does
// not split on separators and never populates virtual directories.
// Returns nullptr when no child matches.
VfsTreeNode*
VfsTreeNode::FindNode(std::string_view NodeName)
{
	const auto& Children = GetChildren();

	for (auto It = Children.begin(); It != Children.end(); ++It)
	{
		if ((*It)->Name == NodeName)
		{
			return It->get();
		}
	}

	return nullptr;
}

// Lazily fills a directory node that has no children yet by asking its data
// source to enumerate it, then re-sorts the subtree into ProjFS order.
//
// Files, and directories that already have children, are left untouched. A
// childless directory without a data source is also left alone; the original
// code dereferenced DataSource unconditionally here.
void
VfsTreeNode::PopulateVirtualNode()
{
	if (!IsDirectory() || !GetChildren().empty())
	{
		// Either a file, or a directory that has already been populated
		return;
	}

	if (!DataSource)
	{
		// Nothing to populate this directory from
		return;
	}

	DataSource->PopulateDirectory(GetPath(), *this);
	NormalizeTree();
}

void
VfsTreeNode::AddFileNode(std::string_view NodeName, uint64_t InFileSize, Oid InChunkId, Ref<VfsTreeDataSource> InDataSource)
{
	auto SlashIndex = NodeName.find_first_of('/');

	if (SlashIndex != std::string_view::npos)
	{
		const std::string_view DirName	 = NodeName.substr(0, SlashIndex);
		const std::string_view ChildPath = NodeName.substr(SlashIndex + 1);

		VfsTreeNode* DirNode = FindNode(DirName);

		if (!DirNode)
		{
			VfsTreeNode* NewNode = AddNewChildNode();
			NewNode->Name		 = std::string(DirName);

			DirNode = NewNode;
		}

		DirNode->AddFileNode(ChildPath, InFileSize, InChunkId, std::move(InDataSource));
	}
	else
	{
		auto InitLeafNode = [&](VfsTreeNode& NewNode) {
			NewNode.Name	   = std::string(NodeName);
			NewNode.FileSize   = InFileSize;
			NewNode.ChunkId	   = InChunkId;
			NewNode.DataSource = std::move(InDataSource);
		};

		for (auto& Child : GetChildren())
		{
			// Already present, just update properties

			if (Child->Name == NodeName)
			{
				InitLeafNode(*Child);

				return;
			}
		}

		InitLeafNode(*AddNewChildNode());
	}
}

void
VfsTreeNode::AddVirtualNode(std::string_view NodeName, Ref<VfsTreeDataSource>&& InDataSource)
{
	auto SlashIndex = NodeName.find_first_of('/');

	if (SlashIndex != std::string_view::npos)
	{
		const std::string_view DirName	 = NodeName.substr(0, SlashIndex);
		const std::string_view ChildPath = NodeName.substr(SlashIndex + 1);

		VfsTreeNode* DirNode = FindNode(DirName);

		if (!DirNode)
		{
			// New directory entry (non-virtual)

			DirNode			  = AddNewChildNode();
			DirNode->Name	  = std::string(DirName);
			DirNode->FileSize = DirectoryMarker;
		}

		DirNode->AddVirtualNode(ChildPath, std::move(InDataSource));
	}
	else
	{
		// Add "leaf" node

		VfsTreeNode* LeafNode = nullptr;

		for (auto& Child : GetChildren())
		{
			// Already present, just update properties

			if (Child->Name == NodeName)
			{
				LeafNode = Child.get();
				break;
			}
		}

		if (!LeafNode)
		{
			LeafNode = AddNewChildNode();
		}

		LeafNode->Name		 = std::string(NodeName);
		LeafNode->FileSize	 = DirectoryMarker;
		LeafNode->DataSource = std::move(InDataSource);
	}
}

static bool
FileNameLessThan(const std::unique_ptr<VfsTreeNode>& Lhs, const std::unique_ptr<VfsTreeNode>& Rhs)
{
	ExtendableWideStringBuilder<64> Name1, Name2;

	Utf8ToWide(Lhs->GetName(), Name1);
	Utf8ToWide(Rhs->GetName(), Name2);

	return PrjFileNameCompare(Name1.c_str(), Name2.c_str()) < 0;
}

// Sorts this node's children, and recursively every descendant's children,
// with FileNameLessThan. ProjFS expects enumeration results in its own
// collation order; unsorted entries show up as spurious duplicates and cause
// odd behaviour.
void
VfsTreeNode::NormalizeTree()
{
	std::sort(std::begin(ChildArray), std::end(ChildArray), FileNameLessThan);

	// Apply the same ordering to each subtree
	for (auto& Subtree : GetChildren())
	{
		Subtree->NormalizeTree();
	}
}

// Appends the names of every node from the root down to this one to Path,
// joined with '\\' (the separator ProjFS uses).
//
// Ancestors are written first, and a separator is only inserted once Path is
// non-empty. A root with an empty name therefore contributes nothing, which
// presumably yields root-relative paths such as "a\\b".
void
VfsTreeNode::GetPath(StringBuilderBase& Path) const
{
	if (ParentNode)
	{
		ParentNode->GetPath(Path);
	}

	if (Path.Size() != 0)
	{
		Path.Append("\\");
	}

	Path.Append(Name);
}

const std::string
VfsTreeNode::GetPath() const
{
	ExtendableStringBuilder<64> Path;
	GetPath(Path);
	return Path.ToString();
}

//////////////////////////////////////////////////////////////////////////

// Private state behind VfsTree. It owns the root node of the tree and keeps a
// non-owning pointer to the data provider used by ReadChunkData.
struct VfsTree::Impl
{
	// Sets the provider without touching the tree
	void Initialize(VfsTreeDataSource& DataProvider) { m_DataProvider = &DataProvider; }

	// Replaces the whole tree with Tree, sorts it into ProjFS order and records
	// the provider.
	// NOTE(review): this move-assigns into the existing root object. If
	// VfsTreeNode's move assignment does not re-parent the moved children,
	// their ParentNode pointers would still refer to `Tree`. Confirm against
	// the VfsTreeNode definition.
	void SetTree(VfsTreeNode&& Tree, VfsTreeDataSource& DataProvider)
	{
		VfsTreeNode& Root = *m_RootNode;
		Root			  = std::move(Tree);
		Root.NormalizeTree();
		m_DataProvider = &DataProvider;
	}

	VfsTreeNode& GetTree() { return *m_RootNode; }

	const VfsTreeNode* FindNode(std::wstring NodeName) { return m_RootNode->FindNode(NodeName); }

	// Forwards to the provider; asserts that Initialize or SetTree set one
	void ReadChunkData(Oid ChunkId, void* DataBuffer, uint64_t ByteOffset, uint64_t Length)
	{
		ZEN_ASSERT(m_DataProvider);
		m_DataProvider->ReadChunkData(ChunkId, DataBuffer, ByteOffset, Length);
	}

	std::unique_ptr<VfsTreeNode> m_RootNode		= std::make_unique<VfsTreeNode>(nullptr);
	VfsTreeDataSource*			 m_DataProvider = nullptr;
};

//////////////////////////////////////////////////////////////////////////

VfsTree::VfsTree() : m_Impl(new Impl)
{
}

VfsTree::~VfsTree()
{
	delete m_Impl;
}

void
VfsTree::Initialize(VfsTreeDataSource& DataProvider)
{
	m_Impl->Initialize(DataProvider);
}

void
VfsTree::SetTree(VfsTreeNode&& Tree, VfsTreeDataSource& DataProvider)
{
	m_Impl->SetTree(std::move(Tree), DataProvider);
}

VfsTreeNode&
VfsTree::GetTree()
{
	return m_Impl->GetTree();
}

VfsTreeNode*
VfsTree::FindNode(std::wstring NodeName)
{
	return m_Impl->m_RootNode->FindNode(NodeName);
}

void
VfsTree::ReadChunkData(Oid ChunkId, void* DataBuffer, uint64_t ByteOffset, uint64_t Length)
{
	m_Impl->ReadChunkData(ChunkId, DataBuffer, ByteOffset, Length);
}

//////////////////////////////////////////////////////////////////////////

// Converts a Win32 GUID (as handed to us by ProjFS) into zen's Guid.
// The bytes are copied rather than read through a reinterpret_cast'ed pointer,
// which would access the object through an unrelated type (strict-aliasing UB).
static Guid
ToGuid(const GUID* GuidPtr)
{
	static_assert(sizeof(Guid) == sizeof(GUID), "zen::Guid must have the same size as GUID");

	Guid Id;
	memcpy(&Id, GuidPtr, sizeof Id);
	return Id;
}

// Derives a deterministic Guid from a path string. The path bytes are hashed
// with IoHash, and the Guid's A/B/C/D fields are filled from the digest bytes
// starting at offsets 0, 4, 8 and 12.
static Guid
PathToGuid(std::string_view Path)
{
	const IoHash PathHash = IoHash::HashBuffer(Path.data(), Path.size());
	const auto*	 Digest	  = &PathHash.Hash[0];

	Guid Result;
	memcpy(&Result.A, Digest + 0, sizeof Result.A);
	memcpy(&Result.B, Digest + 4, sizeof Result.B);
	memcpy(&Result.C, Digest + 8, sizeof Result.C);
	memcpy(&Result.D, Digest + 12, sizeof Result.D);

	return Result;
}

// Rewinds the cursor to the first child of the directory being enumerated.
void
VfsProvider::EnumerationState::RestartScan()
{
	auto&& Children = Node->GetChildren();

	Iterator	= Children.begin();
	EndIterator = Children.end();
}

// Tests the current entry's name against the enumeration's search expression
// using PrjFileNameMatch (ProjFS wildcard semantics).
//
// ProjFS documents SearchExpression as optional: when it is NULL every entry
// must be returned. Treat NULL as match-all instead of passing a null pattern
// to PrjFileNameMatch.
bool
VfsProvider::EnumerationState::MatchesPattern(PCWSTR SearchExpression)
{
	if (SearchExpression == nullptr)
	{
		return true;
	}

	return PrjFileNameMatch(GetCurrentFileName(), SearchExpression);
}

// Advances the cursor by one entry; does nothing once the end is reached.
void
VfsProvider::EnumerationState::MoveToNext()
{
	if (Iterator == EndIterator)
	{
		return;
	}

	++Iterator;
}

// Returns the current entry's name converted to UTF-16.
// The pointer refers to CurrentName, which is reassigned on every call.
// Only valid while the cursor points at an entry.
const WCHAR*
VfsProvider::EnumerationState::GetCurrentFileName()
{
	const VfsTreeNode& Current = **Iterator;
	CurrentName				   = zen::Utf8ToWide(Current.GetName());
	return CurrentName.c_str();
}

// Fills BasicInfo for the current entry and returns a reference to it.
// Directories carry only the directory attribute. Files report their size and
// are marked read-only. All other fields (timestamps) are left zeroed.
PRJ_FILE_BASIC_INFO&
VfsProvider::EnumerationState::GetCurrentBasicInfo()
{
	const VfsTreeNode& Current = **Iterator;

	BasicInfo = {};

	if (Current.IsDirectory())
	{
		BasicInfo.IsDirectory	 = true;
		BasicInfo.FileAttributes = FILE_ATTRIBUTE_DIRECTORY;
	}
	else
	{
		BasicInfo.IsDirectory	 = false;
		BasicInfo.FileSize		 = INT64(Current.GetFileSize());
		BasicInfo.FileAttributes = FILE_ATTRIBUTE_READONLY;
	}

	return BasicInfo;
}

//////////////////////////////////////////////////////////////////////////

VfsProvider::VfsProvider(std::string_view RootPath) : m_RootPath(RootPath)
{
}

VfsProvider::~VfsProvider()
{
}

void
VfsProvider::Initialize()
{
	std::filesystem::path Root{m_RootPath};
	std::filesystem::path ManifestPath = Root / ".zen_vfs";
	bool				  HaveManifest = false;

	if (std::filesystem::exists(Root))
	{
		if (!std::filesystem::is_directory(Root))
		{
			throw std::runtime_error("specified VFS root exists but is not a directory");
		}

		std::error_code Ec;
		m_RootPath = WideToUtf8(CanonicalPath(Root, Ec).c_str());

		if (Ec)
		{
			throw std::system_error(Ec);
		}

		if (std::filesystem::exists(ManifestPath))
		{
			FileContents ManifestData = zen::ReadFile(ManifestPath);
			CbObject	 Manifest	  = LoadCompactBinaryObject(ManifestData.Flatten());

			m_RootGuid = Manifest["id"].AsUuid();

			HaveManifest = true;
		}
		else
		{
			m_RootGuid = PathToGuid(m_RootPath);
		}
	}
	else
	{
		zen::CreateDirectories(Root);
		m_RootGuid = PathToGuid(m_RootPath);
	}

	if (!HaveManifest)
	{
		CbObjectWriter Cbo;
		Cbo << "id" << m_RootGuid;
		Cbo << "path" << m_RootPath;
		CbObject Manifest = Cbo.Save();

		zen::WriteFile(ManifestPath, Manifest.GetBuffer().AsIoBuffer());
	}

	m_IsInitialized.test_and_set();

	// Set up ProjFS root placeholder

	ExtendableWideStringBuilder<256> WideRootPath;
	zen::Utf8ToWide(m_RootPath, WideRootPath);

	PRJ_PLACEHOLDER_VERSION_INFO VersionInfo{};

	if (HRESULT hRes = PrjMarkDirectoryAsPlaceholder(WideRootPath.c_str(), NULL, &VersionInfo, reinterpret_cast<const GUID*>(&m_RootGuid));
		FAILED(hRes))
	{
		ThrowSystemException(hRes, fmt::format("PrjMarkDirectoryAsPlaceholder call for '{}' failed", m_RootPath));
	}
}

// Empties the virtualization root, but only after Initialize() has set
// m_IsInitialized; the flag is then cleared so a second call does nothing.
// Judging by the helper's name, dot-files (such as the .zen_vfs manifest) are
// kept. Confirm in CleanDirectoryExceptDotFiles.
void
VfsProvider::Cleanup()
{
	if (!m_IsInitialized.test())
	{
		return;
	}

	zen::CleanDirectoryExceptDotFiles(m_RootPath);
	m_IsInitialized.clear();
}

void
VfsProvider::AddMount(std::string_view Mountpoint, Ref<VfsTreeDataSource>&& DataSource)
{
	ZEN_ASSERT(m_IsInitialized.test());

	if (!m_Tree)
	{
		m_Tree = new VfsTree;
	}

	auto& Tree = m_Tree->GetTree();

	if (Tree.FindNode(Mountpoint))
	{
		throw std::runtime_error(fmt::format("mount point '{}' already exists", Mountpoint));
	}
	else
	{
		Tree.AddVirtualNode(Mountpoint, std::move(DataSource));
		Tree.NormalizeTree();
	}
}

// Starts ProjFS virtualization of m_RootPath, passing `this` as the instance
// context. Blocks until RequestStop() signals m_StopRunningEvent.
// On exit (normal or via exception), the scope guard calls PrjStopVirtualizing
// and clears m_IsRunning.
// Throws (via ThrowSystemException) if PrjStartVirtualizing fails.
void
VfsProvider::Run()
{
	// Route ProjFS callbacks to the static C trampolines on ProjFsProviderInterface.
	// Cancellation is not wired up (CancelCommandCallback is left null).
	const PRJ_CALLBACKS Callbacks = {.StartDirectoryEnumerationCallback	   = ProjFsProviderInterface::StartDirEnumCallback_C,
									 .EndDirectoryEnumerationCallback	   = ProjFsProviderInterface::EndDirEnumCallback_C,
									 .GetDirectoryEnumerationCallback	   = ProjFsProviderInterface::GetDirEnumCallback_C,
									 .GetPlaceholderInfoCallback		   = ProjFsProviderInterface::GetPlaceholderInfoCallback_C,
									 .GetFileDataCallback				   = ProjFsProviderInterface::GetFileDataCallback_C,
									 /* optional */ .QueryFileNameCallback = ProjFsProviderInterface::QueryFileName_C,
									 /* optional */ .NotificationCallback  = ProjFsProviderInterface::NotificationCallback_C,
									 /* optional */ .CancelCommandCallback = nullptr /* ProjFsProviderInterface::CancelCommand_C */};

	// Request open / pre-rename / pre-delete notifications; Notify() uses the
	// latter two to veto renames and (while running) deletes. The empty
	// NotificationRoot presumably selects the virtualization root itself.
	PRJ_NOTIFICATION_MAPPING Notifications[] = {
		{.NotificationBitMask = PRJ_NOTIFY_FILE_OPENED | PRJ_NOTIFY_PRE_RENAME | PRJ_NOTIFY_PRE_DELETE, .NotificationRoot = L""}};

	// One pool thread and a concurrency limit of 1. This is presumably meant to
	// keep provider callbacks from running concurrently (the enumeration state
	// handling relies on that). Confirm before relying on it elsewhere.
	const PRJ_STARTVIRTUALIZING_OPTIONS VirtOptions = {.Flags					  = PRJ_FLAG_NONE /* PRJ_FLAG_USE_NEGATIVE_PATH_CACHE */,
													   .PoolThreadCount			  = 1,
													   .ConcurrentThreadCount	  = 1,
													   .NotificationMappings	  = Notifications,
													   .NotificationMappingsCount = ZEN_ARRAY_COUNT(Notifications)};

	PRJ_NAMESPACE_VIRTUALIZATION_CONTEXT VirtContext{};

	ExtendableWideStringBuilder<256> WideRootPath;
	zen::Utf8ToWide(m_RootPath, WideRootPath);

	if (HRESULT hRes = PrjStartVirtualizing(WideRootPath.c_str(), &Callbacks, this /* InstanceContext */, &VirtOptions, &VirtContext);
		FAILED(hRes))
	{
		ThrowSystemException(hRes, "PrjStartVirtualizing failed");
	}

	// Only armed once virtualization has actually started
	auto StopVirtualizing = MakeGuard([&] {
		PrjStopVirtualizing(VirtContext);
		m_IsRunning.clear();
	});

	// NOTE(review): m_IsRunning is set only after PrjStartVirtualizing returns.
	// Any callback delivered before this line is answered with
	// ERROR_FILE_NOT_FOUND by the m_IsRunning checks in the callbacks.
	m_IsRunning.test_and_set();

	// Park this thread, waking every 10 s, until the stop event is signalled
	// (Wait presumably returns true once it is)
	while (!m_StopRunningEvent.Wait(10000 /* ms */))
	{
	}
}

void
VfsProvider::RequestStop()
{
	m_StopRunningEvent.Set();
}

//////////////////////////////////////////////////////////////////////////

// ProjFS callback: opens a directory enumeration.
//
// Resolves CallbackData->FilePathName (empty for the virtualization root) and
// lets a virtual directory populate itself. It then stores a cursor over the
// node's children in m_ActiveEnumerations, keyed by EnumerationId, for
// GetDirEnum/EndDirEnum to use.
HRESULT
VfsProvider::StartDirEnum(_In_ const PRJ_CALLBACK_DATA* CallbackData, _In_ const GUID* EnumerationId)
{
	// Nothing can be enumerated until virtualization is running and a mount exists
	if (!m_IsRunning.test() || !m_Tree)
	{
		return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
	}

	const std::string Path	 = WideToUtf8(CallbackData->FilePathName);
	const Guid		  EnumId = ToGuid(EnumerationId);

	VfsTreeNode& Root	  = m_Tree->GetTree();
	VfsTreeNode* EnumNode = Path.empty() ? &Root : Root.FindNode(CallbackData->FilePathName);

	if (EnumNode == nullptr)
	{
		return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
	}

	EnumNode->PopulateVirtualNode();

	auto&& Children = EnumNode->GetChildren();

	RwLock::ExclusiveLockScope Lock(m_EnumerationsLock);
	m_ActiveEnumerations[EnumId] =
		EnumerationState{.Path = Path, .Node = EnumNode, .Iterator = Children.begin(), .EndIterator = Children.end()};

	return S_OK;
}

// ProjFS callback: discards the cursor recorded for EnumerationId, if any.
HRESULT
VfsProvider::EndDirEnum(_In_ const PRJ_CALLBACK_DATA* CallbackData, _In_ const GUID* EnumerationId)
{
	ZEN_UNUSED(CallbackData);

	const Guid EnumId = ToGuid(EnumerationId);

	{
		RwLock::ExclusiveLockScope Lock(m_EnumerationsLock);
		m_ActiveEnumerations.erase(EnumId);
	}

	return S_OK;
}

// ProjFS callback: emits as many matching entries as fit into
// DirEntryBufferHandle, continuing from the cursor stored by StartDirEnum.
//
// An entry that fails to fit is not consumed, so it is retried on the next
// call. If at least one entry was written in this call, success is reported
// even though the buffer filled up. Returns E_INVALIDARG for an unknown
// EnumerationId.
HRESULT
VfsProvider::GetDirEnum(_In_ const PRJ_CALLBACK_DATA*	 CallbackData,
						_In_ const GUID*				 EnumerationId,
						_In_opt_ PCWSTR					 SearchExpression,
						_In_ PRJ_DIR_ENTRY_BUFFER_HANDLE DirEntryBufferHandle)
{
	if (!m_IsRunning.test())
	{
		return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
	}

	const Guid EnumId = ToGuid(EnumerationId);

	EnumerationState* State = nullptr;

	{
		RwLock::ExclusiveLockScope Lock(m_EnumerationsLock);

		const auto EnumIt = m_ActiveEnumerations.find(EnumId);

		if (EnumIt == m_ActiveEnumerations.end())
		{
			return E_INVALIDARG;
		}

		State = &EnumIt->second;
	}

	if (CallbackData->Flags & PRJ_CB_DATA_FLAG_ENUM_RESTART_SCAN)
	{
		State->RestartScan();
	}

	bool AnyFilled = false;

	for (; State->IsCurrentValid(); State->MoveToNext())
	{
		if (!State->MatchesPattern(SearchExpression))
		{
			continue;
		}

		const HRESULT FillResult = PrjFillDirEntryBuffer(State->GetCurrentFileName(), &State->GetCurrentBasicInfo(), DirEntryBufferHandle);

		if (FillResult != S_OK)
		{
			// Stop without advancing (typically the buffer is full); only surface
			// the error if nothing at all was returned by this call
			return AnyFilled ? S_OK : FillResult;
		}

		AnyFilled = true;
	}

	return S_OK;
}

// ProjFS callback: describes the file or directory at CallbackData->FilePathName
// so ProjFS can create a placeholder for it.
// Directories are reported with the directory attribute. Files are reported
// with their size and marked read-only. Both are flagged as not content-indexed.
// Returns ERROR_FILE_NOT_FOUND when the path is unknown, or when no mount has
// been added yet.
HRESULT
VfsProvider::GetPlaceholderInfo(_In_ const PRJ_CALLBACK_DATA* CallbackData)
{
	// m_Tree is only created by AddMount, so guard against callbacks that arrive
	// before any mount exists (StartDirEnum performs the same check)
	if (m_IsRunning.test() == false || !m_Tree)
	{
		return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
	}

	const VfsTreeNode* FoundNode = m_Tree->FindNode(CallbackData->FilePathName);

	if (!FoundNode)
	{
		return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
	}

	PRJ_PLACEHOLDER_INFO PlaceholderInfo;
	memset(&PlaceholderInfo, 0, sizeof PlaceholderInfo);

	if (FoundNode->IsDirectory())
	{
		PlaceholderInfo.FileBasicInfo = {.IsDirectory	 = true,
										 .FileAttributes = FILE_ATTRIBUTE_DIRECTORY | FILE_ATTRIBUTE_NOT_CONTENT_INDEXED};
	}
	else
	{
		PlaceholderInfo.FileBasicInfo = {.IsDirectory	 = false,
										 .FileSize		 = static_cast<INT64>(FoundNode->GetFileSize()),
										 .FileAttributes = FILE_ATTRIBUTE_READONLY | FILE_ATTRIBUTE_NOT_CONTENT_INDEXED};
	}

	// NOTE: for correctness, we should actually provide the canonical casing for the name
	// as stored in our backing store rather than simply providing the name which ProjFS
	// used for the lookup!

	const HRESULT hRes = PrjWritePlaceholderInfo(CallbackData->NamespaceVirtualizationContext,
												 CallbackData->FilePathName,
												 &PlaceholderInfo,
												 sizeof PlaceholderInfo);

	return FAILED(hRes) ? hRes : S_OK;
}

// ProjFS callback: supplies Length bytes starting at ByteOffset for the file at
// CallbackData->FilePathName.
//
// Data comes from the node's data source: by chunk id when one is set,
// otherwise by the node's name. It is staged in a ProjFS-aligned buffer and
// handed to PrjWriteFileData.
//
// Returns ERROR_FILE_NOT_FOUND for unknown paths, directories, nodes without a
// data source, or when no mount exists yet. Returns E_OUTOFMEMORY if the
// aligned buffer cannot be allocated.
HRESULT
VfsProvider::GetFileData(_In_ const PRJ_CALLBACK_DATA* CallbackData, _In_ UINT64 ByteOffset, _In_ UINT32 Length)
{
	// m_Tree is only created by AddMount; guard against an empty provider
	if (m_IsRunning.test() == false || !m_Tree)
	{
		return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
	}

	const VfsTreeNode* FileNode = m_Tree->FindNode(CallbackData->FilePathName);

	if (!FileNode || FileNode->IsDirectory())
	{
		return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
	}

	// Resolve the source before allocating, so an unreadable node costs nothing
	auto Source = FileNode->GetDataSource();

	if (!Source)
	{
		return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
	}

	void* DataBuffer = PrjAllocateAlignedBuffer(CallbackData->NamespaceVirtualizationContext, Length);

	if (DataBuffer == nullptr)
	{
		return E_OUTOFMEMORY;
	}

	auto _ = MakeGuard([&] { PrjFreeAlignedBuffer(DataBuffer); });

	if (const Oid ChunkId = FileNode->GetChunkId(); ChunkId != Oid::Zero)
	{
		Source->ReadChunkData(ChunkId, DataBuffer, ByteOffset, Length);
	}
	else
	{
		Source->ReadNamedData(FileNode->GetName(), DataBuffer, ByteOffset, Length);
	}

	const HRESULT hRes =
		PrjWriteFileData(CallbackData->NamespaceVirtualizationContext, &CallbackData->DataStreamId, DataBuffer, ByteOffset, Length);

	return FAILED(hRes) ? hRes : S_OK;
}

// ProjFS notification callback.
//  - Renames are always refused.
//  - Deletes are refused while virtualization is running, since they would not
//    be reflected into the underlying data source. They are allowed otherwise,
//    e.g. while Cleanup() empties the root.
//  - Every other notification is acknowledged with S_OK.
HRESULT
VfsProvider::Notify(_In_ const PRJ_CALLBACK_DATA*		 CallbackData,
					_In_ BOOLEAN						 IsDirectory,
					_In_ PRJ_NOTIFICATION				 NotificationType,
					_In_opt_ PCWSTR						 DestinationFileName,
					_Inout_ PRJ_NOTIFICATION_PARAMETERS* NotificationParameters)
{
	ZEN_UNUSED(CallbackData, IsDirectory, DestinationFileName, NotificationParameters);

	if (NotificationType == PRJ_NOTIFICATION_PRE_RENAME)
	{
		return HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED);
	}

	if (NotificationType == PRJ_NOTIFICATION_PRE_DELETE && m_IsRunning.test())
	{
		return HRESULT_FROM_NT(0xC0000121L /* STATUS_CANNOT_DELETE */);
	}

	return S_OK;
}

/** ProjFS callback: reports whether CallbackData->FilePathName exists in the
	virtual tree.

	This callback is optional. If the provider does not supply an implementation,
	ProjFS invokes the provider's directory enumeration callbacks to determine
	whether a path exists in the provider's store.

	The provider should use PrjFileNameCompare when searching its backing store.
	Lookups here go through VfsTreeNode::FindNode, which does exactly that.
	Returns ERROR_FILE_NOT_FOUND when the path is unknown, or when no mount has
	been added yet.
 */

HRESULT
VfsProvider::QueryFileName(_In_ const PRJ_CALLBACK_DATA* CallbackData)
{
	// m_Tree is only created by AddMount; guard against an empty provider
	if (m_IsRunning.test() == false || !m_Tree)
	{
		return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
	}

	if (m_Tree->FindNode(CallbackData->FilePathName) != nullptr)
	{
		return S_OK;
	}

	return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
}

// Cancellation is not supported, so this is intentionally a no-op.
// Run() leaves CancelCommandCallback null, so ProjFS is not wired to call this.
void
VfsProvider::CancelCommand(_In_ const PRJ_CALLBACK_DATA* CallbackData)
{
	ZEN_UNUSED(CallbackData);
}

}  // namespace zen

#endif
